# coding: utf-8


import os
import h5py
import numpy as np
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.applications import resnet50, inception_v3, xception
from keras.applications.resnet50 import ResNet50
from keras.applications.inception_v3 import InceptionV3
from keras.applications.xception import Xception
from keras.layers.core import Lambda
from keras.layers import Input, GlobalAveragePooling2D, GlobalMaxPooling2D, Dense, Dropout, Conv2D, MaxPooling2D
from keras.models import Model
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.optimizers import adam, Adam
from sklearn.utils import shuffle
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model
from loss_history import LossHistory
from keras.callbacks import TensorBoard
from keras import backend as K
from keras.models import Model
from keras import layers
from keras.layers import Dense
from keras.layers import Input
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import Conv2D
from keras.layers import SeparableConv2D
from keras.layers import MaxPooling2D
from keras.layers import GlobalAveragePooling2D
from keras.layers import GlobalMaxPooling2D
from keras.engine.topology import get_source_inputs
from keras.utils.data_utils import get_file
from keras.applications.imagenet_utils import decode_predictions
from keras.applications import imagenet_utils
preprocess = imagenet_utils.preprocess_input
from keras import optimizers, regularizers
from keras import backend as K
K.clear_session()


# -------------------------------------
# Training hyperparameters
# -------------------------------------
EPOCHS = 20  # maximum number of training epochs
BATCH_SIZE = 8  # batch size
IMAGE_SIZE = (512, 512)  # target image size (height, width) for flow_from_directory
LEARNING_RATE = 1e-4  # initial learning rate for the Adam optimizer
EARLY_STOPPING_PATIENCE = 3  # epochs with no val_loss improvement before stopping
REDUCE_LR_PATIENCE = 3  # epochs with no val_loss improvement before reducing LR
CLASS_NUM = 2  # number of output classes
DROP_RATE = 0.2  # dropout rate before the final Dense layer
WEIGHT_DECAY=1e-4  # L2 regularization strength (used by the module-level regularizers below)
BN_NORM = True  # NOTE(review): defined but not referenced in this file — confirm intent
DATA_FORMAT='channels_last'  # NOTE(review): defined but not referenced in this file — confirm intent


# -------------------------------------
# File and directory paths
# -------------------------------------
TRAIN_DATA_PATH = './data/train/'  # training images, one sub-directory per class
VALID_DATA_PATH = './data/valid/'  # validation images, one sub-directory per class
RESULT_PATH = './result/'  # output directory for plots (model diagram, loss/acc curves)
MODEL_PATH = './model/'  # output directory for the checkpointed model
OUTPUT_PATH = './output/'  # NOTE(review): defined but not referenced in this file — confirm intent
MODEL_NAME = 'my-model.hdf5'  # filename for the best-model checkpoint


def my_preprocess(x):
    """Scale raw 8-bit pixel values into the [0, 1] range."""
    return x / 255.0


# L2 regularizers built from WEIGHT_DECAY.
# NOTE(review): neither regularizer is passed to any layer in this file —
# presumably they were intended for the Conv2D/Dense layers in create_model;
# confirm before relying on weight decay being active.
kernel_regularizer = regularizers.l2(WEIGHT_DECAY)
bias_regularizer = regularizers.l2(WEIGHT_DECAY)


def create_model():
    """Build and train the multi-branch CNN classifier.

    Reads grayscale images (resized to IMAGE_SIZE) from TRAIN_DATA_PATH /
    VALID_DATA_PATH, trains for up to EPOCHS epochs with checkpointing,
    early stopping and LR reduction on plateau, and writes the model
    diagram plus loss/accuracy curves into RESULT_PATH. The best model
    (by val_loss) is saved to MODEL_PATH/MODEL_NAME.

    Returns: None. Side effects: files written under RESULT_PATH and
    MODEL_PATH; the Keras session is cleared on exit.
    """
    # Generators yield raw pixel values; scaling to [0, 1] happens inside
    # the model via the Lambda(my_preprocess) layer below.
    # flow_from_directory defaults to class_mode='categorical', i.e.
    # one-hot labels of length CLASS_NUM.
    train_gen = ImageDataGenerator()
    valid_gen = ImageDataGenerator()
    train_generator = train_gen.flow_from_directory(TRAIN_DATA_PATH, IMAGE_SIZE, shuffle=True, batch_size=BATCH_SIZE, color_mode='grayscale')
    valid_generator = valid_gen.flow_from_directory(VALID_DATA_PATH, IMAGE_SIZE, batch_size=BATCH_SIZE, color_mode='grayscale')

    # Single-channel (grayscale) input, normalized in-graph.
    inputs = Input((*IMAGE_SIZE, 1))
    x_input = Lambda(my_preprocess)(inputs)

    # block1 — two 3x3 convs + pool. (FIX: dropped the redundant
    # input_shape= kwarg; the shape is fixed by the Input layer above.)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block1_conv1')(x_input)
    x = Conv2D(64, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block1_pool')(x)

    # block2
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block2_conv1')(x)
    x = Conv2D(128, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block2_pool')(x)

    # block3
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block3_conv3')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='block3_pool')(x)

    # side1 — Xception-style separable-conv branch off block3's output,
    # reduced to a vector by global average pooling.
    x_side1 = SeparableConv2D(512, (3, 3), padding='same', use_bias=False, name='side1_sepconv1')(x)
    x_side1 = BatchNormalization(name='side1_bn1')(x_side1)
    x_side1 = Activation('relu', name='side1_act1')(x_side1)
    x_side1 = SeparableConv2D(512, (3, 3), padding='same', use_bias=False, name='side1_sepconv2')(x_side1)
    x_side1 = BatchNormalization(name='side1_bn2')(x_side1)
    x_side1 = Activation('relu', name='side1_act2')(x_side1)
    x_side1 = MaxPooling2D((2, 2), strides=(2, 2), padding='same', name='side1_pool')(x_side1)
    x_side1 = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='side1_sepconv3')(x_side1)
    x_side1 = BatchNormalization(name='side1_bn3')(x_side1)
    x_side1 = Activation('relu', name='side1_act3')(x_side1)
    x_side1 = SeparableConv2D(728, (3, 3), padding='same', activation='relu', name='side1_sepconv4')(x_side1)
    x_side1 = GlobalAveragePooling2D(name='side1_gap')(x_side1)

    # side2 — two stacked Inception-style multi-path modules off block3's
    # output (1x1 / 1x1->3x3 / 3x3->1x1 paths concatenated), then GAP.
    x_side2_1_1 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side2_1_conv1')(x)
    x_side2_1_2 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side2_2_conv1')(x)
    x_side2_1_2 = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='side2_2_conv2')(x_side2_1_2)
    x_side2_1_3 = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='side2_3_conv1')(x)
    x_side2_1_3 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side2_3_conv2')(x_side2_1_3)
    x_side2_1 = keras.layers.concatenate([x_side2_1_1, x_side2_1_2, x_side2_1_3])

    x_side2_2_1 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side3_1_conv1')(x_side2_1)
    x_side2_2_2 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side3_2_conv1')(x_side2_1)
    x_side2_2_2 = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='side3_2_conv2')(x_side2_2_2)
    x_side2_2_3 = Conv2D(256, (3, 3), strides=(1, 1), padding='same', activation='relu', name='side3_3_conv1')(x_side2_1)
    x_side2_2_3 = Conv2D(256, (1, 1), strides=(1, 1), padding='same', activation='relu', name='side3_3_conv2')(x_side2_2_3)

    x_side2_2 = keras.layers.concatenate([x_side2_2_1, x_side2_2_2, x_side2_2_3])
    x_side2 = GlobalAveragePooling2D(name='side2_gap')(x_side2_2)

    # block4 — main trunk continues from block3, then GAP.
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), strides=(1, 1), padding='same', activation='relu', name='block4_conv3')(x)

    x = GlobalAveragePooling2D(name='gap')(x)

    # Fuse the trunk with both side branches, then classify.
    x = keras.layers.concatenate([x, x_side1, x_side2])

    x = Dropout(DROP_RATE, name='dropout1')(x)
    predictions = Dense(CLASS_NUM, activation='softmax', name='dense1')(x)
    model = Model(inputs=inputs, outputs=predictions)
    model.summary()
    plot_model(model, to_file=os.path.join(RESULT_PATH, 'my_model.png'), show_shapes=True)

    # Save the best model (lowest val_loss) seen so far.
    check_point = ModelCheckpoint(monitor='val_loss',
                                  filepath=os.path.join(MODEL_PATH, MODEL_NAME),
                                  verbose=1,
                                  save_best_only=True,
                                  save_weights_only=False,
                                  mode='auto')

    # early stopping on stalled val_loss
    early_stopping = EarlyStopping(monitor='val_loss', patience=EARLY_STOPPING_PATIENCE, verbose=0, mode='auto')

    # reduce LR by 10x when val_loss plateaus
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=REDUCE_LR_PATIENCE, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)

    # Custom callback that records per-batch/per-epoch loss and accuracy.
    history = LossHistory()

    # compile
    # FIX: use the Adam class (the lowercase `adam` module member is not a
    # callable optimizer on current Keras versions), and use
    # categorical_crossentropy — the generators emit one-hot categorical
    # labels, for which binary_crossentropy also skews the accuracy metric.
    model.compile(optimizer=Adam(lr=LEARNING_RATE), loss='categorical_crossentropy', metrics=['accuracy'])

    # fit
    # FIX: reduce_lr was previously constructed but never passed here, so
    # LR reduction never ran.
    model.fit_generator(
        train_generator,
        steps_per_epoch=train_generator.samples // BATCH_SIZE,
        epochs=EPOCHS,
        validation_data=valid_generator,
        validation_steps=valid_generator.samples // BATCH_SIZE,
        callbacks=[check_point, early_stopping, reduce_lr, history]
    )

    # Plot loss/accuracy curves at batch and epoch granularity.
    history.loss_plot('batch', os.path.join(RESULT_PATH, 'my_loss_batch.png'))
    history.acc_plot('batch', os.path.join(RESULT_PATH, 'my_batch.png'))
    history.loss_plot('epoch', os.path.join(RESULT_PATH, 'my_loss_epoch.png'))
    history.acc_plot('epoch', os.path.join(RESULT_PATH, 'my_acc_epoch.png'))
    K.clear_session()

# Script entry point: build and train the model when run directly.
if __name__ == '__main__':
    create_model()
